#!/usr/bin/env python
# coding=utf-8
"""
Build a cropped dataset from the WikiArt dataset.
This script reads images from local WikiArt dataset folders,
creates random crops from them, and saves them to a new directory
structured into train, validation, and test sets.
"""

import argparse
import concurrent.futures
import os
import random
from typing import List, Optional, Tuple

import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm

# Some image files may be truncated or corrupted. Enabling this flag makes PIL
# load whatever data a truncated file contains instead of raising, so the
# script tolerates such files.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def get_all_image_paths(
    root_dir: str,
    extensions: Tuple[str, ...] = ('.jpg',),
) -> List[str]:
    """
    Recursively collect image file paths under ``root_dir``.

    The whole directory tree is walked, which fits WikiArt's layout of one
    sub-folder per style. Extension matching is case-insensitive.

    Args:
        root_dir: Root directory of the dataset.
        extensions: File extensions to accept, including the leading dot.
            Defaults to ``('.jpg',)``. Pass e.g. ``('.jpg', '.jpeg', '.png')``
            to widen the filter. A single string is also accepted.

    Returns:
        Sorted list of matching file paths. If ``root_dir`` is not a
        directory, an error is printed and an empty list is returned.
    """
    # str.endswith needs a tuple; wrap a lone string so it is not split
    # into individual characters by the lower-casing below.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    if not os.path.isdir(root_dir):
        print(f"Error: Source directory {root_dir} not found.")
        return []

    print(f"Scanning for images in {root_dir}...")
    image_paths = [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filename.lower().endswith(suffixes)
    ]

    # os.walk order depends on the filesystem; sorting makes the list order
    # stable, so the same random seed selects the same images on every run.
    return sorted(image_paths)

def random_crop(
    image_path: str,
    crop_size: int,
    rng: Optional[random.Random] = None,
) -> Optional[np.ndarray]:
    """
    Create a random ``crop_size x crop_size`` RGB crop from an image.

    If either side of the image is smaller than ``crop_size``, the image is
    first upscaled with LANCZOS resampling so that a crop fits.

    Args:
        image_path: Path of the source image.
        crop_size: Side length of the square crop, in pixels.
        rng: Random generator used to choose the crop offset. Defaults to the
            module-level ``random`` functions. Pass a dedicated
            ``random.Random`` when calling from worker threads, so the result
            does not depend on thread scheduling.

    Returns:
        The crop as a numpy array built from an RGB PIL image, or ``None``
        if the image could not be opened or decoded.
    """
    rand = random if rng is None else rng

    try:
        # The context manager closes the file handle promptly. convert()
        # forces a full decode while the file is still open.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
    except (OSError, ValueError, SyntaxError, Image.DecompressionBombError) as e:
        # DecompressionBombError (oversized images) is NOT an OSError; if it
        # were left uncaught it would escape the worker and abort the whole
        # run. ValueError/SyntaxError are included for malformed files that
        # some Pillow decoders report that way.
        print(f"Warning: Could not open image {image_path}, skipping. Error: {e}")
        return None

    width, height = image.size

    # Upscale when either side is below crop_size. The extra 5% margin keeps
    # both sides >= crop_size after the int() truncation below.
    if width < crop_size or height < crop_size:
        scale = max(crop_size / width, crop_size / height) * 1.05
        new_width, new_height = int(width * scale), int(height * scale)
        image = image.resize((new_width, new_height), Image.LANCZOS)
        width, height = image.size

    # Random top-left corner such that the crop lies fully inside the image.
    left = rand.randint(0, width - crop_size)
    top = rand.randint(0, height - crop_size)

    crop = image.crop((left, top, left + crop_size, top + crop_size))
    return np.array(crop)


def process_and_save_image(args: Tuple) -> bool:
    """
    Process a single task: create a random crop and save it as PNG.

    Args:
        args: ``(source_image_path, save_path, crop_size)``, optionally
            followed by an integer seed. When a seed is given, the crop
            offset is drawn from a private ``random.Random(seed)``, which
            makes the result independent of which thread runs the task.

    Returns:
        True on success, False if the image could not be read or the crop
        could not be saved.
    """
    source_image_path, save_path, crop_size, *rest = args
    rng = random.Random(rest[0]) if rest else None

    crop_array = random_crop(source_image_path, crop_size, rng)
    if crop_array is None:
        return False

    try:
        # Ensure the destination directory exists. A bare file name has no
        # directory part, and os.makedirs("") would raise.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        Image.fromarray(crop_array).save(save_path, 'PNG')
    except Exception as e:
        # Best-effort: report the failure and let the other tasks continue.
        print(f"Error saving crop to {save_path}: {e}")
        return False
    return True


def create_crops_for_split(
    all_source_paths: List[str],
    output_dir: str,
    crop_size: int,
    num_crops: int,
    split_name: str,
    max_workers: Optional[int] = None,
) -> None:
    """
    Create ``num_crops`` random crops for one data split (train/val/test).

    Each crop comes from an image chosen uniformly at random (with
    replacement) from ``all_source_paths``. Crops are saved as
    ``<output_dir>/<index:08d>.png``, with indices starting at 0.

    Args:
        all_source_paths: Pool of source image paths to sample from.
        output_dir: Directory that receives the PNG crops.
        crop_size: Side length of each square crop, in pixels.
        num_crops: Number of crops to generate.
        split_name: Human-readable split name, used for logging.
        max_workers: Thread-pool size. ``None`` uses the
            ``ThreadPoolExecutor`` default.
    """
    print(f"\nProcessing split: {split_name}")
    print(f"Destination: {output_dir}")
    print(f"Number of crops to generate: {num_crops}")

    os.makedirs(output_dir, exist_ok=True)

    # All randomness is drawn here, in the calling thread, so the task list
    # is a deterministic function of the global seed. Each task also gets its
    # own seed for the crop offset. Drawing from the shared `random` module
    # inside worker threads would make offsets depend on thread scheduling
    # and break reproducibility.
    sources = [random.choice(all_source_paths) for _ in range(num_crops)]
    seeds = [random.getrandbits(32) for _ in range(num_crops)]
    tasks = [
        (source, os.path.join(output_dir, f"{i:08d}.png"), crop_size, seed)
        for i, (source, seed) in enumerate(zip(sources, seeds))
    ]

    # Process images in parallel with a progress bar.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(
            executor.map(process_and_save_image, tasks),
            total=len(tasks),
            desc=f"Creating {split_name} crops",
        ))

    # Failed tasks leave gaps in the file numbering; make that visible.
    num_failed = results.count(False)
    if num_failed:
        print(f"Warning: {num_failed} of {len(tasks)} crops failed for split '{split_name}'.")

def main() -> int:
    """
    Parse command-line arguments and build the cropped WikiArt dataset.

    Creates ``train``, ``val`` and ``test`` sub-folders under
    ``--output_dir``, each filled with the requested number of crops.

    Returns:
        Process exit status: 0 on success, 1 when no source images were
        found. Invalid arguments exit through ``argparse`` (status 2).
    """
    parser = argparse.ArgumentParser(description="Build a cropped dataset from WikiArt")
    # Change the default paths to match your environment and dataset.
    parser.add_argument("--source_dir", type=str,
                        default="",
                        help="Base source directory of the WikiArt dataset")
    parser.add_argument("--output_dir", type=str,
                        default="",
                        help="Base output directory for cropped images")
    parser.add_argument("--crop_size", type=int, default=256,
                        help="Size of the crops (n x n)")
    parser.add_argument("--train_crops", type=int, default=10000,
                        help="Number of training crops to generate")
    parser.add_argument("--val_crops", type=int, default=5000,
                        help="Number of validation crops to generate")
    parser.add_argument("--test_crops", type=int, default=5000,
                        help="Number of test crops to generate")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")

    args = parser.parse_args()

    # Fail early with a clear message. An empty source dir can never contain
    # images, and a non-positive crop size cannot produce a valid crop.
    if not args.source_dir:
        parser.error("--source_dir must be provided")
    if args.crop_size <= 0:
        parser.error("--crop_size must be a positive integer")

    # Seed the global RNGs so that sampling is driven by --seed. No numpy
    # randomness is visible here; numpy is seeded as a precaution.
    random.seed(args.seed)
    np.random.seed(args.seed)

    print("Starting WikiArt dataset cropping process...")
    print(f"Source directory: {args.source_dir}")
    print(f"Output directory: {args.output_dir}")

    # 1. Collect every source image path.
    all_image_paths = get_all_image_paths(args.source_dir)
    if not all_image_paths:
        print("Fatal: No images found in the source directory. Exiting.")
        return 1
    print(f"Found a total of {len(all_image_paths)} source images.")

    # 2. Splits to build, as (destination sub-folder, number of crops).
    splits_to_process = [
        ('train', args.train_crops),
        ('val', args.val_crops),
        ('test', args.test_crops),
    ]

    # 3. Generate the crops for each split.
    # NOTE(review): every split samples from the same full image pool, so one
    # painting can contribute crops to train AND val/test (possible leakage).
    # Confirm this is intended.
    for dest_folder, num_crops in splits_to_process:
        if num_crops <= 0:
            print(f"Skipping '{dest_folder}' as number of crops is {num_crops}.")
            continue

        create_crops_for_split(
            all_source_paths=all_image_paths,
            output_dir=os.path.join(args.output_dir, dest_folder),
            crop_size=args.crop_size,
            num_crops=num_crops,
            split_name=dest_folder,
        )

    print(f"\nDone! Cropped WikiArt dataset saved in: {args.output_dir}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a fatal error yields a non-zero exit code.
    raise SystemExit(main())